
[JAX] GEMM tex and FFI cleanup #2739

Open
phu0ngng wants to merge 9 commits into NVIDIA:main from phu0ngng:gemm_cleanup

Conversation

@phu0ngng
Collaborator

@phu0ngng phu0ngng commented Mar 5, 2026

Description

This PR removes unused, untested, and partially supported features from the public GEMM primitive: fused GeLU, bias gradient, and grad mode — these were dead code paths not exercised by any JAX-side caller.

In addition, this PR removes FP8_2X_ACC_XGRAD from the QuantizeConfig, as it is no longer inferable from the recipes. Users can instead set the accumulation precision via the new environment variable TE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION.

Change details:

  • Introduce GemmV2FFI, replacing the old GemmFFI; remove untested/unused boolean flags (fuse_bias, fuse_gelu, grad); and consolidate the remaining individual attributes into a GemmConfig struct.
  • Bias fusion is now inferred from bias.size > 0 rather than a separate fuse_bias flag.
  • The old GemmFFI is kept as a deprecated shim (it warns once via std::call_once and is scheduled for removal in September 2026) to avoid breaking existing users.
  • The TE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION environment variable is introduced to set the accumulation precision in MatMul, i.e., whether to promote the intermediate accumulation result to a higher-precision dtype for storage.
  • Move the assert_cublas_requirements checks from lowering to abstract, providing earlier shape validation.
  • In test_distributed_dense.py, an output sharding constraint is added to ensure the correct sharding pattern for the input gradients in the bprop.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

phu0ngng added 3 commits March 5, 2026 09:29
Signed-off-by: Phuong Nguyen <phuonguyen@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 5, 2026

Greptile Summary

This PR is a well-motivated refactoring of the JAX GEMM stack: it removes dead code paths (fused GeLU, bias gradient, grad mode), consolidates the C++ FFI parameters into a GemmConfig struct, introduces GemmV2FFI/GemmInitV2FFI, and replaces per-recipe FP8_2X_ACC_* fields with a single NVTE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION env var. Previous review issues (missing return in deprecated shim, non-thread-safe warned_* flags, use_split_accumulator dropped on JAX fallback path, _sharded_impl tuple mismatch) are all correctly addressed in this version.

Key verified findings:

  • Return type annotation: gemm() is annotated as returning Tuple[jnp.ndarray, ...] but actually returns a bare jax.Array. This should be corrected for type safety.
  • Comment documentation: The comment explaining use_split_accumulator lacks the environment variable name (NVTE_FP8_GEMM_HIGH_PRECISION_ACCUMULATION) and has a minor grammar issue.
  • Breaking positional API change: bias was inserted as the 3rd positional parameter, shifting contracting_dims from index 2 to index 3. External code using positional syntax for contracting_dims will silently break, but the PR does not mark this as a breaking change. Internal callers safely use keyword arguments.
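The positional-shift hazard can be sketched with simplified stand-in signatures (these are illustrative, not the real gemm() signature):

```python
# Hypothetical simplified signatures, for illustration only.
def gemm_old(lhs, rhs, contracting_dims=((1,), (0,))):
    return ("contracting_dims", contracting_dims)

def gemm_new(lhs, rhs, bias=None, contracting_dims=((1,), (0,))):
    # `bias` inserted as the 3rd positional parameter shifts
    # `contracting_dims` from index 2 to index 3.
    return ("contracting_dims", contracting_dims)

# A positional caller written against the old signature now silently
# passes its contracting_dims tuple as `bias`:
old_call = gemm_old(None, None, ((0,), (1,)))
new_call = gemm_new(None, None, ((0,), (1,)))   # dims silently land in `bias`
safe_call = gemm_new(None, None, contracting_dims=((0,), (1,)))  # keyword: OK
```

This is why keyword-argument callers (as all internal callers are) are unaffected while positional callers break without an error.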

The core logic is sound and previous review issues have been properly addressed. The identified issues are fixable improvements to API documentation and type safety.

Confidence Score: 4/5

  • Safe to merge with minor documentation and type annotation fixes. The core refactoring logic is sound and well-executed.
  • The PR correctly addresses previous review concerns (missing return statements, thread-safety, environment variable handling, tuple unpacking). The identified issues are fixable improvements to documentation and type safety, not functional bugs. Internal callers use keyword arguments so positional API change doesn't impact them. The return type annotation mismatch and missing env var documentation are correctable without changing behavior.
  • transformer_engine/jax/cpp_extensions/gemm.py (3 minor issues: return type annotation, comment documentation, positional API documentation)

Comments Outside Diff (1)

  1. transformer_engine/jax/cpp_extensions/gemm.py, line 1717 (link)

    The return type annotation does not match the actual return type. The function returns a bare jax.Array, but the annotation says Tuple[jnp.ndarray, ...].

    Looking at the function body:

    • Line 1769: return output (bare array for JAX fallback path)
    • Line 1783: return output (bare array after unpacking from _te_gemm)

    The docstring correctly documents this as returning a single jax.Array (line 1742), but the type annotation should match.
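The suggested annotation fix can be sketched as follows (Array here is a stand-in for jax.Array, and the bodies are placeholders, not the real implementation):

```python
from typing import Tuple

class Array:
    """Stand-in for jax.Array in this sketch."""

def gemm_before(lhs, rhs) -> Tuple[Array, ...]:
    output = Array()  # annotation promises a tuple...
    return output     # ...but a bare array is returned

def gemm_after(lhs, rhs) -> Array:
    output = Array()  # annotation now matches the bare-array return
    return output
```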

Last reviewed commit: 04d2b92

phu0ngng and others added 4 commits March 5, 2026 09:53
@phu0ngng
Collaborator Author

phu0ngng commented Mar 5, 2026

/te-ci JAX L1

Collaborator

@jberchtold-nvidia left a comment


Overall LGTM, thanks for this PR! Left some small comments and questions

std::vector<size_t> buffer_shape{1, 1};
auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(buffer_shape, DType::kFloat32,
                                                                JAXX_Collective_Op::ALL_GATHER);
[[maybe_unused]] auto _ = CollectiveGemmPlanRegistry::getInstance().get_executor(
Collaborator

This wasn't introduced in this PR, but why do we need to have an auto _ =? Would it not work to just call CollectiveGemmPlanRegistry::getInstance().get_executor?

Collaborator Author

It would work without assigning the unused return value in modern C++, but I think silently discarding it is not good practice.

transformer_engine::jax::GemmConfig,
    ::xla::ffi::StructMember<transformer_engine::jax::JAXX_Scaling_Mode>("scaling_mode"),
    ::xla::ffi::StructMember<transformer_engine::jax::JAXX_Collective_Op>("collective_op"),
    ::xla::ffi::StructMember<int64_t>("lhs_axis_boundary"),
Collaborator

Good idea, this struct-based approach lets us better set default values for newly added fields, right? So it's easier to be backwards compatible with older HLO but still have the flexibility to add new attributes as long as the default value keeps the same behavior as before. I recall doing something similar for attention

Collaborator Author

The one that we have in attention is slightly different. There, we don't define a struct that XLA automatically decodes; instead, each attribute is queried dynamically at runtime, which is harder to debug:

size_t input_batch = get_attr_value<int64_t>(attrs, "input_batch"); \
size_t bias_batch = get_attr_value<int64_t>(attrs, "bias_batch"); \
size_t q_max_seqlen = get_attr_value<int64_t>(attrs, "q_max_seqlen"); \
size_t kv_max_seqlen = get_attr_value<int64_t>(attrs, "kv_max_seqlen"); \
size_t attn_heads = get_attr_value<int64_t>(attrs, "attn_heads"); \
size_t num_gqa_groups = get_attr_value<int64_t>(attrs, "num_gqa_groups"); \
size_t bias_heads = get_attr_value<int64_t>(attrs, "bias_heads"); \
size_t qk_head_dim = get_attr_value<int64_t>(attrs, "qk_head_dim"); \
size_t v_head_dim = get_attr_value<int64_t>(attrs, "v_head_dim"); \
size_t max_segments_per_seq = get_attr_value<int64_t>(attrs, "max_segments_per_seq"); \

I think this struct approach is better: it's flexible enough that we don't need to introduce a new API whenever we want to add an attribute, it allows optional attributes (i.e., struct members with default values), and it is less bug-prone.
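The backward-compatibility argument can be sketched with a Python analogue of the struct-with-defaults idea (field names are borrowed from the registered struct members above; use_split_accumulator stands in for a hypothetical later addition):

```python
from dataclasses import dataclass

@dataclass
class GemmConfig:
    # Fields mirroring the registered struct members:
    scaling_mode: int = 0
    collective_op: int = 0
    lhs_axis_boundary: int = -1
    # A hypothetical newly added attribute; its default preserves the
    # pre-existing behavior, so older callers need no changes:
    use_split_accumulator: bool = False

# Code written before the new field existed still constructs a valid config:
old_style = GemmConfig(scaling_mode=1, collective_op=2, lhs_axis_boundary=1)
```

As long as each new member's default reproduces the old behavior, older HLO (or older callers) keep working unchanged.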

@phu0ngng
Collaborator Author

phu0ngng commented Mar 6, 2026

/te-ci JAX L1


2 participants